diff --git a/train.py b/train.py
index 6e3b058..8c61ed4 100755
--- a/train.py
+++ b/train.py
@@ -58,6 +58,9 @@ try:
 except ImportError: 
     has_wandb = False
 
+import sky_callback
+from sky_callback import step_iterator
+
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('train')
 
@@ -609,6 +612,11 @@ def main():
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
             f.write(args_text)
 
+    sky_callback.init(
+        global_rank=args.rank,
+        total_steps=(num_epochs - start_epoch) * len(loader_train),  # only the epochs this run executes (loop starts at start_epoch)
+    )
+
     try:
         for epoch in range(start_epoch, num_epochs):
             if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
@@ -674,7 +682,7 @@ def train_one_epoch(
     end = time.time()
     last_idx = len(loader) - 1
     num_updates = epoch * len(loader)
-    for batch_idx, (input, target) in enumerate(loader):
+    for batch_idx, (input, target) in step_iterator(enumerate(loader)):
         last_batch = batch_idx == last_idx
         data_time_m.update(time.time() - end)
         if not args.prefetcher:
